Align brain surfaces of 2 individuals with fMRI data
In this example, we align 2 low-resolution left hemispheres using 4 fMRI feature maps (z-score contrast maps).
import gdist
import matplotlib as mpl
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import numpy as np
from fugw.mappings import FUGW
from mpl_toolkits.axes_grid1 import make_axes_locatable
from nilearn import datasets, image, plotting, surface
Let’s download 5 volumetric contrast maps per individual
using nilearn’s API. We will use the first 4 of them
to compute an alignment between the source and target subjects,
and use the left-out contrast to assess the quality of our alignment.
n_subjects = 2
contrasts = [
"sentence reading vs checkerboard",
"sentence listening",
"calculation vs sentences",
"left vs right button press",
"checkerboard",
]
n_training_contrasts = 4
brain_data = datasets.fetch_localizer_contrasts(
contrasts,
n_subjects=n_subjects,
get_anats=True,
)
source_imgs_paths = brain_data["cmaps"][0 : len(contrasts)]
target_imgs_paths = brain_data["cmaps"][len(contrasts) : 2 * len(contrasts)]
Dataset created in /github/home/nilearn_data/brainomics_localizer
Downloading data from https://osf.io/hwbm2/download ...
...done. (3 seconds, 0 min)
Downloading data from https://osf.io/download/5d27cd441c5b4a001aa08008/ ...
...done. (3 seconds, 0 min)
Downloading data from https://osf.io/download/5d27c03e45253a001c3e189f/ ...
...done. (2 seconds, 0 min)
Downloading data from https://osf.io/download/5d27bfd0114a420016057cba/ ...
...done. (2 seconds, 0 min)
Downloading data from https://osf.io/download/5d27cb281c5b4a001aa07e29/ ...
...done. (2 seconds, 0 min)
Downloading data from https://osf.io/download/5d27cc0845253a001c3e22bd/ ...
...done. (2 seconds, 0 min)
Downloading data from https://osf.io/download/5d27d10b114a420019044ed8/ ...
...done. (2 seconds, 0 min)
Downloading data from https://osf.io/download/5d27d89d1c5b4a001d9f5e6e/ ...
...done. (2 seconds, 0 min)
Downloading data from https://osf.io/download/5d27d429a26b340017083380/ ...
...done. (2 seconds, 0 min)
Downloading data from https://osf.io/download/5d27ddc91c5b4a001b9ef9d0/ ...
...done. (3 seconds, 0 min)
Downloading data from https://osf.io/download/5d27d14f114a420019044efc/ ...
...done. (3 seconds, 0 min)
Downloading data from https://osf.io/download/5d275eb845253a001c3dbf76/ ...
...done. (5 seconds, 0 min)
Downloading data from https://osf.io/download/5d275ede1c5b4a001aa00c26/ ...
...done. (5 seconds, 0 min)
Downloading data from https://osf.io/download/5d27037f45253a001c3d4563/ ...
...done. (3 seconds, 0 min)
Downloading data from https://osf.io/download/5d7b8948fcbf44001c44e695/ ...
...done. (2 seconds, 0 min)
/usr/local/lib/python3.8/site-packages/nilearn/datasets/func.py:763: UserWarning: `legacy_format` will default to `False` in release 0.11. Dataset fetchers will then return pandas dataframes by default instead of recarrays.
warnings.warn(_LEGACY_FORMAT_MSG)
Here is what the first contrast map of the source subject looks like (the following figure is interactive):
contrast_index = 0
plotting.view_img(
source_imgs_paths[contrast_index],
brain_data["anats"][0],
title=f"Contrast {contrast_index} (source subject)",
opacity=0.5,
)
/usr/local/lib/python3.8/site-packages/nilearn/_utils/niimg.py:63: UserWarning: Non-finite values detected. These values will be replaced with zeros.
warn(
/usr/local/lib/python3.8/site-packages/numpy/core/fromnumeric.py:784: UserWarning: Warning: 'partition' will ignore the 'mask' of the MaskedArray.
a.partition(kth, axis=axis, kind=kind, order=order)
Computing feature arrays
Let's project these 5 maps onto a mesh representing the cortical surface and aggregate these projections to build an array of features for the source and target subjects. To keep the training phase of our mapping short even on CPU, we project the volumetric maps onto a very low-resolution mesh made of 642 vertices.
fsaverage3 = datasets.fetch_surf_fsaverage(mesh="fsaverage3")
def load_images_and_project_to_surface(image_paths):
"""Util function for loading and projecting volumetric images."""
images = [image.load_img(img) for img in image_paths]
surface_images = [
np.nan_to_num(surface.vol_to_surf(img, fsaverage3.pial_left))
for img in images
]
return np.stack(surface_images)
source_features = load_images_and_project_to_surface(source_imgs_paths)
target_features = load_images_and_project_to_surface(target_imgs_paths)
source_features.shape
/usr/local/lib/python3.8/site-packages/nilearn/surface/surface.py:465: RuntimeWarning: Mean of empty slice
texture = np.nanmean(all_samples, axis=2)
(5, 642)
Here is a figure showing the 5 projected maps for each of the 2 individuals:
def plot_surface_map(surface_map, cmap="coolwarm", colorbar=True, **kwargs):
"""Util function for plotting surfaces."""
plotting.plot_surf(
fsaverage3.pial_left,
surface_map,
cmap=cmap,
colorbar=colorbar,
bg_map=fsaverage3.sulc_left,
bg_on_data=True,
darkness=0.5,
**kwargs,
)
fig = plt.figure(figsize=(3 * n_subjects, 3 * len(contrasts)))
grid_spec = gridspec.GridSpec(len(contrasts), n_subjects, figure=fig)
# Print all feature maps
for i, contrast_name in enumerate(contrasts):
for j, features in enumerate([source_features, target_features]):
ax = fig.add_subplot(grid_spec[i, j], projection="3d")
plot_surface_map(
features[i, :], axes=ax, vmax=10, vmin=-10, colorbar=False
)
# Add labels to subplots
if i == 0:
for j in range(2):
ax = fig.add_subplot(grid_spec[i, j])
ax.axis("off")
ax.text(0.5, 1, f"sub-0{j}", va="center", ha="center")
ax = fig.add_subplot(grid_spec[i, :])
ax.axis("off")
ax.text(0.5, 0, contrast_name, va="center", ha="center")
# Add colorbar
ax = fig.add_subplot(grid_spec[2, :])
ax.axis("off")
divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="2%")
fig.add_axes(cax)
fig.colorbar(
mpl.cm.ScalarMappable(
norm=mpl.colors.Normalize(vmin=-10, vmax=10), cmap="coolwarm"
),
cax=cax,
)
plt.show()

Computing geometry arrays
Now we compute the kernel matrix of distances between vertices on the cortical surface. Note that in this example, we are using the same mesh for the source and target individuals, but this does not have to be the case in general.
def compute_geometry_from_mesh(mesh_path):
"""Util function to compute matrix of geodesic distances of a mesh."""
(coordinates, triangles) = surface.load_surf_mesh(mesh_path)
geometry = gdist.local_gdist_matrix(
coordinates.astype(np.float64), triangles.astype(np.int32)
).toarray()
return geometry
fsaverage3_pial_left_geometry = compute_geometry_from_mesh(
fsaverage3.pial_left
)
source_geometry = fsaverage3_pial_left_geometry
target_geometry = fsaverage3_pial_left_geometry
source_geometry.shape
(642, 642)
Each row vertex_index of the geometry matrices contains the anatomical
distance (here in millimeters) from vertex_index to all other vertices
of the mesh.
vertex_index = 4
fig = plt.figure(figsize=(5, 5))
ax = fig.add_subplot(111, projection="3d")
ax.set_title("Geodesic distance in mm\non the cortical surface")
plot_surface_map(
source_geometry[vertex_index, :],
cmap="magma",
cbar_tick_format="%.2f",
axes=ax,
)
plt.show()

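For intuition, here is what such a distance-matrix row means on a toy graph, using scipy's shortest-path routine as a stand-in for gdist's geodesic distances (a minimal sketch, not the fsaverage3 mesh):

```python
import numpy as np
from scipy.sparse.csgraph import shortest_path

# Toy 4-vertex chain "mesh", given as an edge-weight matrix (0 = no edge)
edges = np.array([
    [0.0, 1.0, 0.0, 0.0],
    [1.0, 0.0, 2.0, 0.0],
    [0.0, 2.0, 0.0, 1.5],
    [0.0, 0.0, 1.5, 0.0],
])
geometry = shortest_path(edges, directed=False)

# Row 0 holds the distance from vertex 0 to every vertex of the graph,
# accumulated along the shortest path (e.g. 0 -> 1 -> 2 -> 3 costs 4.5)
print(geometry[0])
```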
Normalizing features and geometries
Features and geometries should be normalized before we can train a mapping.
Indeed, without this scaling, it's unclear whether the source and target
features would be comparable. Moreover, the hyper-parameter alpha would
depend on the scale of the respective matrices. Finally, skipping this step
can empirically lead to NaN values in the computed transport plan.
source_features_normalized = source_features / np.linalg.norm(
source_features, axis=1
).reshape(-1, 1)
target_features_normalized = target_features / np.linalg.norm(
target_features, axis=1
).reshape(-1, 1)
source_geometry_normalized = source_geometry / np.max(source_geometry)
target_geometry_normalized = target_geometry / np.max(target_geometry)
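As a sanity check, the same normalization applied to toy arrays (stand-ins for the real data) yields unit-norm feature rows and distances in [0, 1]:

```python
import numpy as np

rng = np.random.default_rng(0)
features = rng.normal(size=(5, 642))            # toy stand-in for contrast maps
geometry = np.abs(rng.normal(size=(642, 642)))  # toy stand-in for distances

features_normalized = features / np.linalg.norm(features, axis=1).reshape(-1, 1)
geometry_normalized = geometry / np.max(geometry)

# Every feature map now has unit L2 norm, and all distances lie in [0, 1]
assert np.allclose(np.linalg.norm(features_normalized, axis=1), 1.0)
assert geometry_normalized.max() == 1.0
```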
Training the mapping
Let's create our mapping. We set alpha=0.5 to indicate that we are
as interested in matching vertices with similar features as we are in
preserving the anatomical geometries of the source and target subjects.
We leave rho at its default value, and finally set a value of eps
which is low enough that the computed transport plan is not too entropic.
High values of eps lead to faster computations and more entropic plans;
low values of eps lead to slower computations, but finer-grained plans.
Note that this package is meant to be used with GPUs; fitting mappings
on CPU is about 100x slower.
mapping = FUGW(alpha=0.5, eps=1e-4)
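To build intuition for eps, here is a minimal balanced Sinkhorn solver on a toy cost matrix (an illustration of entropic regularization in general, not the FUGW solver itself), showing that a larger eps yields a higher-entropy, blurrier plan:

```python
import numpy as np

def sinkhorn(cost, eps, n_iters=500):
    """Minimal balanced Sinkhorn with uniform marginals (toy illustration)."""
    n, m = cost.shape
    a, b = np.ones(n) / n, np.ones(m) / m  # uniform source / target weights
    K = np.exp(-cost / eps)                # Gibbs kernel
    v = np.ones(m)
    for _ in range(n_iters):
        u = a / (K @ v)
        v = b / (K.T @ u)
    return u[:, None] * K * v[None, :]

def entropy(pi):
    p = pi / pi.sum()
    return -(p * np.log(p + 1e-30)).sum()

cost = np.abs(np.arange(4)[:, None] - np.arange(4)[None, :]).astype(float)
sharp = sinkhorn(cost, eps=1e-2)  # low eps: mass concentrates on the diagonal
blurry = sinkhorn(cost, eps=1.0)  # high eps: mass spreads out

# The high-eps plan has strictly higher entropy, i.e. it is more "blurred"
assert entropy(blurry) > entropy(sharp)
```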
Let’s fit our mapping! Remember to use the training maps only. Moreover, we limit the number of block-coordinate-descent iterations to 3 in order to limit computation time for this example.
_ = mapping.fit(
source_features_normalized[:n_training_contrasts],
target_features_normalized[:n_training_contrasts],
source_geometry=source_geometry_normalized,
target_geometry=target_geometry_normalized,
solver="sinkhorn",
solver_params={
"nits_bcd": 3,
},
verbose=True,
)
[21:02:45] BCD step 1/3 FUGW loss: 0.028870292007923126 dense.py:360
(base) 0.02978154458105564 (entropic)
[21:03:19] BCD step 2/3 FUGW loss: 0.004489848855882883 dense.py:360
(base) 0.005622141994535923 (entropic)
[21:03:52] BCD step 3/3 FUGW loss: 0.004457559436559677 dense.py:360
(base) 0.005594924092292786 (entropic)
Here is the evolution of the FUGW loss during training, with and without the entropic term:
fig, ax = plt.subplots(figsize=(4, 4))
ax.set_title(
"Sinkhorn mapping training loss\n"
f"Total training time = {mapping.loss_times[-1]:.1f}s"
)
ax.set_ylabel("Loss")
ax.set_xlabel("BCD step")
ax.plot(mapping.loss_steps, mapping.loss, label="FUGW loss")
ax.plot(mapping.loss_steps, mapping.loss_entropic, label="FUGW entropic loss")
ax.legend()
plt.show()

Note that we used the sinkhorn solver here because it is well known
in the optimal transport community, but this library comes with
other solvers which are, in most cases, much faster.
Let's retrain our mapping using the mm solver, which implements
a majorization-minimization approach to approximate a solution and is
used by default in fugw.mappings:
mm_mapping = FUGW(alpha=0.5, rho=1, eps=1e-4)
_ = mm_mapping.fit(
source_features_normalized[:n_training_contrasts],
target_features_normalized[:n_training_contrasts],
source_geometry=source_geometry_normalized,
target_geometry=target_geometry_normalized,
solver="mm",
solver_params={
"nits_bcd": 5,
"tol_bcd": 1e-10,
"tol_uot": 1e-10,
},
verbose=True,
)
[21:03:57] BCD step 1/5 FUGW loss: 0.03685737028717995 dense.py:360
(base) 0.036940909922122955 (entropic)
[21:04:03] BCD step 2/5 FUGW loss: 0.014586355537176132 dense.py:360
(base) 0.015035441145300865 (entropic)
[21:04:08] BCD step 3/5 FUGW loss: 0.006973237730562687 dense.py:360
(base) 0.007662989664822817 (entropic)
[21:04:15] BCD step 4/5 FUGW loss: 0.005839386489242315 dense.py:360
(base) 0.006633028853684664 (entropic)
[21:04:27] BCD step 5/5 FUGW loss: 0.005359851289540529 dense.py:360
(base) 0.006219192408025265 (entropic)
And now with the ibpp solver:
ibpp_mapping = FUGW(alpha=0.5, rho=1, eps=1e-4)
_ = ibpp_mapping.fit(
source_features_normalized[:n_training_contrasts],
target_features_normalized[:n_training_contrasts],
source_geometry=source_geometry_normalized,
target_geometry=target_geometry_normalized,
solver="ibpp",
solver_params={
"nits_bcd": 5,
"tol_bcd": 1e-10,
"tol_uot": 1e-10,
},
verbose=True,
)
[21:04:32] BCD step 1/5 FUGW loss: 0.034325726330280304 dense.py:360
(base) 0.03454320505261421 (entropic)
[21:04:38] BCD step 2/5 FUGW loss: 0.007011039648205042 dense.py:360
(base) 0.00773261021822691 (entropic)
[21:04:48] BCD step 3/5 FUGW loss: 0.005309466738253832 dense.py:360
(base) 0.006187541410326958 (entropic)
[21:05:06] BCD step 4/5 FUGW loss: 0.0049124350771307945 dense.py:360
(base) 0.005864681676030159 (entropic)
[21:05:26] BCD step 5/5 FUGW loss: 0.004740338306874037 dense.py:360
(base) 0.0057375673204660416 (entropic)
Here is the evolution of the FUGW loss during training,
without the entropic term. Note how, in this case,
even though mm and ibpp needed more block-coordinate-descent steps
to converge, they were about 2 to 3 times faster to reach the same final
FUGW training loss as sinkhorn.
You might want to tweak solver parameters like nits_bcd and nits_uot
to get the fastest convergence rates.
fig = plt.figure(figsize=(4 * 2, 4))
fig.suptitle("Training loss comparison\nSinkhorn vs MM vs IBPP")
ax = fig.add_subplot(121)
ax.set_ylabel("Loss")
ax.set_xlabel("BCD step")
ax.plot(mapping.loss_steps, mapping.loss, label="Sinkhorn FUGW loss")
ax.plot(mm_mapping.loss_steps, mm_mapping.loss, label="MM FUGW loss")
ax.plot(ibpp_mapping.loss_steps, ibpp_mapping.loss, label="IBPP FUGW loss")
ax.legend()
ax = fig.add_subplot(122)
ax.set_ylabel("Loss")
ax.set_xlabel("Time (in seconds)")
ax.plot(mapping.loss_times, mapping.loss, label="Sinkhorn FUGW loss")
ax.plot(mm_mapping.loss_times, mm_mapping.loss, label="MM FUGW loss")
ax.plot(ibpp_mapping.loss_times, ibpp_mapping.loss, label="IBPP FUGW loss")
ax.legend()
fig.tight_layout()
plt.show()

Using the computed mapping
The computed mapping is stored in mapping.pi as a torch.Tensor.
In this example, the transport plan is small enough that we can display
it in full.
pi = mapping.pi.numpy()
fig, ax = plt.subplots(figsize=(10, 10))
ax.set_title("Transport plan", fontsize=20)
ax.set_xlabel("target vertices", fontsize=15)
ax.set_ylabel("source vertices", fontsize=15)
im = plt.imshow(pi, cmap="viridis")
plt.colorbar(im, ax=ax, shrink=0.8)
plt.show()

Each row vertex_index of the computed mapping can be interpreted as
a probability map describing which vertices of the target
the source vertex vertex_index should be mapped to.
probability_map = pi[vertex_index, :] / np.linalg.norm(pi[vertex_index, :])
fig = plt.figure(figsize=(5, 5))
ax = fig.add_subplot(111, projection="3d")
ax.set_title(
"Probability map of target vertices\n"
f"being matched with source vertex {vertex_index}"
)
plot_surface_map(probability_map, cmap="viridis", axes=ax)
plt.show()

Using mapping.transform(),
we can use the computed mapping to transport any collection of feature maps
from the source anatomy onto the target anatomy.
Conversely, mapping.inverse_transform() takes feature maps
from the target anatomy and transports them onto the source anatomy.
contrast_index = 2
predicted_target_features = mapping.transform(
source_features[contrast_index, :]
)
predicted_target_features.shape
(642,)
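Under the hood, transform amounts (to our understanding, up to implementation details) to a barycentric projection of the source map through the transport plan. A minimal numpy sketch on toy arrays, not the fugw API:

```python
import numpy as np

rng = np.random.default_rng(0)
n_source, n_target = 6, 4
pi = rng.random((n_source, n_target))   # toy transport plan (not a real mapping)
source_map = rng.normal(size=n_source)  # toy feature map on the source mesh

# Barycentric projection: each target vertex receives a weighted average of
# source values, the weights being its column of the plan
predicted_target = (pi.T @ source_map) / pi.sum(axis=0)

# The inverse direction weights by rows of the plan instead
predicted_source = (pi @ predicted_target) / pi.sum(axis=1)
```

Because each target value is a convex combination of source values, the transported map stays within the range of the source map.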
fig = plt.figure(figsize=(3 * 3, 3))
fig.suptitle("Transporting feature maps of the training set")
grid_spec = gridspec.GridSpec(1, 3, figure=fig)
ax = fig.add_subplot(grid_spec[0, 0], projection="3d")
ax.set_title("Actual source features")
plot_surface_map(
source_features[contrast_index, :], axes=ax, vmax=10, vmin=-10
)
ax = fig.add_subplot(grid_spec[0, 1], projection="3d")
ax.set_title("Predicted target features")
plot_surface_map(predicted_target_features, axes=ax, vmax=10, vmin=-10)
ax = fig.add_subplot(grid_spec[0, 2], projection="3d")
ax.set_title("Actual target features")
plot_surface_map(
target_features[contrast_index, :], axes=ax, vmax=10, vmin=-10
)
plt.show()

Here, we transported a feature map which is part of the training set, which does not really help evaluate the quality of our model. Instead, we can use the computed mapping to transport unseen data, which is how we will usually assess whether our model has captured useful information:
contrast_index = len(contrasts) - 1
predicted_target_features = mapping.transform(
source_features[contrast_index, :]
)
fig = plt.figure(figsize=(3 * 3, 3))
fig.suptitle("Transporting feature maps of the test set")
grid_spec = gridspec.GridSpec(1, 3, figure=fig)
ax = fig.add_subplot(grid_spec[0, 0], projection="3d")
ax.set_title("Actual source features")
plot_surface_map(
source_features[contrast_index, :], axes=ax, vmax=10, vmin=-10
)
ax = fig.add_subplot(grid_spec[0, 1], projection="3d")
ax.set_title("Predicted target features")
plot_surface_map(predicted_target_features, axes=ax, vmax=10, vmin=-10)
ax = fig.add_subplot(grid_spec[0, 2], projection="3d")
ax.set_title("Actual target features")
plot_surface_map(
target_features[contrast_index, :], axes=ax, vmax=10, vmin=-10
)
plt.show()

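To turn this visual check into a number, one could correlate the predicted and actual target maps on the left-out contrast. A minimal sketch of such a score on synthetic vectors (the variable names are illustrative, not part of the fugw API):

```python
import numpy as np

rng = np.random.default_rng(0)
actual_target = rng.normal(size=642)                           # toy "actual" map
predicted_target = actual_target + 0.5 * rng.normal(size=642)  # toy "prediction"

# Pearson correlation between the predicted and actual left-out maps;
# a score close to 1 suggests the mapping generalizes to unseen contrasts
score = np.corrcoef(predicted_target, actual_target)[0, 1]
```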
Total running time of the script: (3 minutes 58.554 seconds)
Estimated memory usage: 242 MB